use std::collections::HashSet;

use squawk_syntax::{
    Parse, SourceFile,
    ast::{self, AstNode},
    identifier::Identifier,
};

use crate::{Edit, Fix, Linter, Rule, Violation};

use lazy_static::lazy_static;

use crate::visitors::{check_not_allowed_types, is_not_valid_int_type};

lazy_static! {
    /// Spellings of the serial pseudo-types that this rule reports and offers to
    /// rewrite as identity columns.
    static ref SERIAL_TYPES: HashSet<Identifier> = [
        "serial",
        "serial2",
        "serial4",
        "serial8",
        "smallserial",
        "bigserial",
    ]
    .iter()
    .copied()
    .map(Identifier::new)
    .collect();
}

/// Maps a serial pseudo-type name, exactly as written in the source, to the
/// identity column definition that replaces it.
///
/// Unquoted names are matched case-insensitively (`BIGSERIAL` == `bigserial`).
/// A double-quoted name such as `"bigserial"` has its quotes removed and is
/// matched case-sensitively, mirroring how Postgres treats quoted identifiers,
/// so `"bigserial"` still maps to `bigint` rather than being narrowed to
/// `integer`. Any unrecognised name falls back to `integer`.
fn replace_serial(serial_type: &str) -> &'static str {
    // Postgres folds unquoted identifiers to lower case but keeps quoted ones verbatim.
    let name = match serial_type
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
    {
        Some(quoted) => quoted.to_owned(),
        None => serial_type.to_lowercase(),
    };
    match name.as_str() {
        "serial" | "serial4" => "integer generated by default as identity",
        "serial2" | "smallserial" => "smallint generated by default as identity",
        "serial8" | "bigserial" => "bigint generated by default as identity",
        _ => "integer generated by default as identity",
    }
}

/// Builds the autofix that swaps the whole `ty` node for an identity column
/// definition chosen by `replace_serial` from the node's first token.
///
/// Returns `None` when the type node has no tokens.
fn create_identity_fix(ty: &ast::Type) -> Option<Fix> {
    let node = ty.syntax();
    let leading = node.first_token()?;
    let replacement = replace_serial(leading.text());
    Some(Fix::new(
        "Replace with IDENTITY column",
        vec![Edit::replace(node.text_range(), replacement)],
    ))
}

/// Callback for `check_not_allowed_types`: reports a `PreferIdentity`
/// violation when `is_not_valid_int_type` matches `ty` against `SERIAL_TYPES`,
/// attaching the identity-column autofix when one can be built.
///
/// A `None` type is ignored.
fn check_ty_for_serial(ctx: &mut Linter, ty: Option<ast::Type>) {
    let Some(ty) = ty else {
        return;
    };
    if !is_not_valid_int_type(&ty, &SERIAL_TYPES) {
        return;
    }

    // The fix is optional: `create_identity_fix` yields `None` for a token-less node.
    let identity_fix = create_identity_fix(&ty);
    let violation = Violation::for_node(
        Rule::PreferIdentity,
        "Serial types make schema, dependency, and permission management difficult.".into(),
        ty.syntax(),
    )
    .help("Use an `IDENTITY` column instead.")
    .fix(identity_fix);
    ctx.report(violation);
}

/// Entry point for the `prefer-identity` rule: hands the parsed file's syntax
/// tree to `check_not_allowed_types`, using `check_ty_for_serial` as the
/// per-type check.
pub(crate) fn prefer_identity(ctx: &mut Linter, parse: &Parse<SourceFile>) {
    check_not_allowed_types(ctx, &parse.tree(), check_ty_for_serial);
}

#[cfg(test)]
mod test {
    use insta::{assert_debug_snapshot, assert_snapshot};

    use crate::{
        Rule,
        test_utils::{fix_sql, lint},
    };

    /// Applies the `PreferIdentity` autofix to `sql` and returns the rewritten SQL.
    fn fix(sql: &str) -> String {
        fix_sql(sql, Rule::PreferIdentity)
    }

    #[test]
    fn fix_serial_types() {
        assert_snapshot!(fix("create table users (id serial);"), @"create table users (id integer generated by default as identity);");
        assert_snapshot!(fix("create table users (id serial2);"), @"create table users (id smallint generated by default as identity);");
        assert_snapshot!(fix("create table users (id serial4);"), @"create table users (id integer generated by default as identity);");
        assert_snapshot!(fix("create table users (id serial8);"), @"create table users (id bigint generated by default as identity);");
        assert_snapshot!(fix("create table users (id smallserial);"), @"create table users (id smallint generated by default as identity);");
        assert_snapshot!(fix("create table users (id bigserial);"), @"create table users (id bigint generated by default as identity);");
    }

    #[test]
    fn fix_mixed_case() {
        assert_snapshot!(fix("create table users (id BIGSERIAL);"), @"create table users (id bigint generated by default as identity);");
        assert_snapshot!(fix("create table users (id Serial);"), @"create table users (id integer generated by default as identity);");
    }

    #[test]
    fn fix_quoted_serial() {
        // Quoted serial types are reported (see `ok_when_quoted`); the fix must
        // replace the whole quoted type name, quotes included.
        assert_snapshot!(fix(r#"create table users (id "serial");"#), @"create table users (id integer generated by default as identity);");
    }

    #[test]
    fn err() {
        let sql = r#"
create table users (
    id serial
);
create table users (
    id serial2
);
create table users (
    id serial4
);
create table users (
    id serial8
);
create table users (
    id smallserial
);
create table users (
    id bigserial
);
create table users (
    id BIGSERIAL
);
        "#;
        let errors = lint(sql, Rule::PreferIdentity);
        // One violation per statement, all from this rule.
        assert_eq!(errors.len(), 7);
        assert_eq!(
            errors
                .iter()
                .filter(|x| x.code == Rule::PreferIdentity)
                .count(),
            7
        );
        assert_debug_snapshot!(errors);
    }

    #[test]
    fn ok_when_quoted() {
        // Despite the name, quoted serial type names are still reported.
        let sql = r#"
create table users (
    id "serial"
);
create table users (
    id "bigserial"
);
        "#;
        let errors = lint(sql, Rule::PreferIdentity);
        assert_eq!(errors.len(), 2);
        assert_debug_snapshot!(errors);
    }

    #[test]
    fn ok() {
        let sql = r#"
create table users (
    id  bigint generated by default as identity primary key
);
create table users (
    id  bigint generated always as identity primary key
);
        "#;
        let errors = lint(sql, Rule::PreferIdentity);
        assert_eq!(errors.len(), 0);
    }
}
